{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4",
      "authorship_tag": "ABX9TyNYrI4hHEz/+IX9KImm3ba2",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/RQLuo/MixTeX/blob/main/MixTex_Demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Model Import and Inference"
      ],
      "metadata": {
        "id": "7-ASTSmn3UpZ"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "kCELUXkedivU",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "492370ee-5fce-4843-85b7-d4928cc0c554"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n",
            "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
            "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
            "You will be able to reuse this secret in all of your notebooks.\n",
            "Please note that authentication is recommended but still optional to access public models or datasets.\n",
            "  warnings.warn(\n"
          ]
        }
      ],
      "source": [
        "from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProcessor\n",
        "from PIL import Image\n",
        "import requests\n",
        "\n",
        "# Load the image processor, tokenizer and encoder-decoder weights published under the same checkpoint name.\n",
        "feature_extractor = AutoImageProcessor.from_pretrained(\"MixTex/ZhEn-Latex-OCR\")\n",
        "# `model_max_length` is the documented keyword for the tokenizer length limit;\n",
        "# the older `max_len` spelling is only kept by the library as a backward-compatibility alias.\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"MixTex/ZhEn-Latex-OCR\", model_max_length=296)\n",
        "model = VisionEncoderDecoderModel.from_pretrained(\"MixTex/ZhEn-Latex-OCR\")\n",
        "# Optional inference demo (left disabled): uncomment the two lines below to fetch a sample image,\n",
        "# run generation on it and print the decoded LaTeX with \\\\[ ... \\\\] rewritten as an align* environment.\n",
        "#imgen = Image.open(requests.get('https://cdn-uploads.huggingface.co/production/uploads/62dbaade36292040577d2d4f/eOAym7FZDsjic_8ptsC-H.png', stream=True).raw)\n",
        "#print(tokenizer.decode(model.generate(feature_extractor(imgen, return_tensors=\"pt\").pixel_values)[0]).replace('\\\\[','\\\\begin{align*}').replace('\\\\]','\\\\end{align*}'))"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Dataset Import and Model Training"
      ],
      "metadata": {
        "id": "wruZtc2v3pMo"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%pip install -q datasets accelerate"
      ],
      "metadata": {
        "id": "lDf3-mX4OSiV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n",
        "from datasets import load_dataset\n",
        "from PIL import Image\n",
        "import torch\n",
        "\n",
        "# Load the MixTex/Pseudo-Latex-ZhEn-1 dataset; later cells index the result by split name ('train').\n",
        "dataframe = load_dataset(path=\"MixTex/Pseudo-Latex-ZhEn-1\")"
      ],
      "metadata": {
        "id": "bBwBcLuoZD7C"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from torch.utils.data import Dataset\n",
        "class MixTexDataset(Dataset):\n",
        "    \"\"\"Map-style dataset that turns (image, text) rows into model training inputs.\n",
        "\n",
        "    Each item is a dict with two tensors:\n",
        "      * ``pixel_values`` - the feature extractor's output for one RGB image, with the\n",
        "        leading batch axis removed;\n",
        "      * ``labels`` - token ids of the target text, padded/truncated to ``max_length``,\n",
        "        with every padding id replaced by -100.\n",
        "\n",
        "    Args:\n",
        "        dataframe: Mapping from split name to an indexable collection of rows; each row\n",
        "            must provide an ``image`` entry (with a ``convert`` method) and a ``text`` entry.\n",
        "        tokenizer: Callable tokenizer returning ``input_ids`` and exposing ``pad_token_id``.\n",
        "        feature_extractor: Callable image processor returning ``pixel_values``.\n",
        "        max_length: Fixed label length used for padding and truncation.\n",
        "        split: Name of the split to read from ``dataframe`` (defaults to ``\"train\"``).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, dataframe, tokenizer, feature_extractor, max_length=256, split=\"train\"):\n",
        "        self.dataframe = dataframe\n",
        "        self.tokenizer = tokenizer\n",
        "        self.max_length = max_length\n",
        "        self.feature_extractor = feature_extractor\n",
        "        self.split = split\n",
        "\n",
        "    def __len__(self):\n",
        "        \"\"\"Return the number of rows in the selected split.\"\"\"\n",
        "        return len(self.dataframe[self.split])\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        \"\"\"Build the ``pixel_values`` / ``labels`` pair for row ``idx``.\"\"\"\n",
        "        # Fetch the row a single time: each row access may decode the stored image again.\n",
        "        row = self.dataframe[self.split][idx]\n",
        "        image = row['image'].convert(\"RGB\")\n",
        "        pixel_values = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n",
        "        target = self.tokenizer(row['text'], padding=\"max_length\", max_length=self.max_length, truncation=True).input_ids\n",
        "        # -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so padded positions add no loss.\n",
        "        pad_id = self.tokenizer.pad_token_id\n",
        "        labels = [-100 if token == pad_id else token for token in target]\n",
        "        # Remove only the leading batch axis; a bare squeeze() would also collapse any other size-1 axis.\n",
        "        return {\"pixel_values\": pixel_values.squeeze(0), \"labels\": torch.tensor(labels)}\n",
        "# Wrap the loaded data with the tokenizer and image processor created earlier in the notebook.\n",
        "traindataset = MixTexDataset(\n",
        "    dataframe=dataframe,\n",
        "    tokenizer=tokenizer,\n",
        "    feature_extractor=feature_extractor,\n",
        ")"
      ],
      "metadata": {
        "id": "Os2DH0vIiJI2"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "training_args = Seq2SeqTrainingArguments(\n",
        "    # Where checkpoints and logs are written.\n",
        "    output_dir=\"./results\",\n",
        "    logging_dir='./logs',\n",
        "    # Optimisation settings.\n",
        "    num_train_epochs=3,\n",
        "    per_device_train_batch_size=12,\n",
        "    learning_rate=5e-5,\n",
        "    fp16=True,\n",
        "    # Logging and checkpoint cadence.\n",
        "    logging_steps=100,\n",
        "    save_steps=500,\n",
        "    save_total_limit=1,\n",
        "    # Seq2Seq-specific switch.\n",
        "    predict_with_generate=True,\n",
        ")\n",
        "\n",
        "# Bind the model, the training arguments and the training data; no evaluation set is supplied.\n",
        "trainer = Seq2SeqTrainer(model=model, args=training_args, train_dataset=traindataset)"
      ],
      "metadata": {
        "id": "Ec0vj-4_iGtG"
      },
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Run fine-tuning. `resume_from_checkpoint=None` is the default and starts from the weights currently in `model`;\n",
        "# pass True instead to continue from the latest checkpoint in `output_dir` after an interrupted session.\n",
        "trainer.train(resume_from_checkpoint=None)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 144
        },
        "id": "yhkY3Ok_nUwR",
        "outputId": "4058e96f-b64d-4687-c679-278a977b01ba"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='165' max='29436' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [  165/29436 01:47 < 5:21:08, 1.52 it/s, Epoch 0.02/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>100</td>\n",
              "      <td>0.195300</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ]
          },
          "metadata": {}
        }
      ]
    }
  ]
}